function out = fct_EM_conditional(EMinput)
% FCT_EM_CONDITIONAL  Improve an initial guess of the K-type model using EM.
%
% Starting from an initial guess (g, mu, sigma) this function iterates
% E and M steps of the EM algorithm. This version uses the conditional
% distribution, conditioning on (t1,t2) >= TL and (t1,t2) <= TU, and solves
% the ML first-order conditions of each M step with an inner fixed-point
% iteration.
%
% INPUT
%   EMinput : struct. Every field is unpacked into a local variable of the
%             same name. The code below reads the following variables, so
%             they must be fields of EMinput:
%       g_vec_ini, mu_vec_ini, sigma_vec_ini : K x 1 initial guesses
%       data_EM            : matrix with three columns; columns 1 and 2 hold
%                            the weeks t1 and t2, column 3 the number of
%                            spells with that combination of (t1,t2)
%       tstep              : added to t1 and t2 (see below)
%       TL, TU             : lower / upper bounds passed to the CDF helpers
%       n_EM, tol          : maximum number of EM steps and tolerance on the
%                            maximum parameter change between EM steps
%       Niter_MLE, toler_MLE : iteration cap and tolerance of the inner
%                            fixed-point iteration for (mu, sigma)
%       mu_min, sigma_min, g_min, z_min : lower bounds / fallback values
%       use_Elog           : 1 selects the log-difference estimate of E[1/t]
%       tokeep             : number of final EM iterates kept as candidates
%       choose_the_highest : 1 returns the kept iterate with the highest
%                            stored log-likelihood
%       print_iter         : 1 prints the EM iteration counter
%
% OUTPUT
%   out : struct with fields
%       g_vec, mu_vec, sigma_vec                   final (or selected) estimates
%       log_like, log_like_partial, log_like_pure_EM   per-iteration values
%       *_max                                      values stored when log_like
%                                                  reached its highest value
%       mu_vec_iter, sigma_vec_iter, g_vec_iter    parameters at the start of
%                                                  each EM iteration
%       diff_vec, diff_mu_vec, diff_sigma_vec, diff_g_vec
%                                                  parameter changes per step
%
% Calls fct_pdf_f, fct_CDF_F_capped and fct_CDF_F_capped_prime, which are
% defined elsewhere.

% Unpack every field of EMinput into a local variable of the same name.
% Field names are valid identifiers, so the eval string is a plain assignment.
mm_mat = fieldnames( EMinput );
for i = 1 : length(mm_mat)
    eval([mm_mat{i} ' = EMinput.(mm_mat{i});']);
end

% initial guess, updated in the loop
g_vec      = g_vec_ini;
mu_vec     = mu_vec_ini;
sigma_vec  = sigma_vec_ini;

% number of distinct (t1,t2) pairs = number of rows of data_EM
% (max(size(.)) would be wrong for data with fewer than three rows)
MM      = size(data_EM,1);
Npeople = sum(data_EM(:,3));
K       = length(g_vec);

%%%% THESE ARE SET OUTSIDE (fields of EMinput) =============================
%n_EM  = 300;       % maximum number of steps of the EM algorithm
%tol   = 10^(-8);   % tolerance on maximum difference between iterations of EM
%g_min       = 1e-20 ;
%sigma_min   = 0.00001;  % this is needed so that f > 0
%z_min       = 1e-300;
%Niter_MLE = 5000 ;      % maximum number of iterations for each type k
%toler_MLE = 10^(-7);    % tolerance to stop iterations for each type k
%choose_the_highest = 1;
%tokeep = 20;
%%%% =======================================================================

diff_EM      = tol+1;  % initial difference, guarantees at least one EM step
i_EM         = 1;      % index of the EM step
log_like_max = -Inf;   % highest log-likelihood achieved so far

% tstep is added because the probability refers to the week that starts at
% data_EM(:,1) or data_EM(:,2)
t1_vec = data_EM(:,1)+tstep; % vector of times t1
t2_vec = data_EM(:,2)+tstep; % vector of times t2

% Floored copies used only when use_Elog == 1: the minimum value is
% 3/2*tstep, so that t - tstep >= tstep/2 and, for tstep > 0, the logarithm
% below is never taken at zero.
t1_veclog = max(t1_vec,3/2*tstep);
t2_veclog = max(t2_vec,3/2*tstep);

% The log-likelihood tends to be non-monotone and the loop typically stops
% at n_EM. Hence the last `tokeep` iterations before n_EM are stored and,
% if requested, the one with the highest log-likelihood is returned.
kb = 0;                              % number of filled slots
mu_vec_aux           = zeros(K,tokeep);
sigma_vec_aux        = zeros(K,tokeep);
g_vec_aux            = zeros(K,tokeep);
log_like_aux         = zeros(1,tokeep);
log_like_partial_aux = zeros(1,tokeep);
log_like_pure_EM_aux = zeros(1,tokeep);

% parameters at the start of every EM iteration
mu_vec_iter    = zeros(K,n_EM);
sigma_vec_iter = zeros(K,n_EM);
g_vec_iter     = zeros(K,n_EM);

% work arrays, fully overwritten in every EM iteration
y          = zeros(1,K);
Flow_vec   = zeros(K,1);
Fhigh_vec  = zeros(K,1);
E_t        = zeros(1,K);
E_1_over_t = zeros(1,K);
mu_new     = zeros(K,1);
sigma_new  = zeros(K,1);

%
% next loop implements the EM algorithm
%
while i_EM < n_EM + 1 && diff_EM > tol

    mu_vec_iter(:,i_EM)    = mu_vec;
    sigma_vec_iter(:,i_EM) = sigma_vec;
    g_vec_iter(:,i_EM)     = g_vec;

    if print_iter == 1
       disp(' ')
       disp(['EM iteration = ', num2str(i_EM)]);
    end

    %%%% E step: updates shares z
    %
    % z(k,j) = share of observations with pair (t1,t2) indexed by j that
    % comes from type k; k in 1,..,K and j in 1,..,MM.
    z = zeros(K,MM);
    for k = 1:K
        mu     =  mu_vec(k,1);
        sigma  =  sigma_vec(k,1);

        Flow  = fct_CDF_F_capped(TL,mu,sigma);
        Fhigh = fct_CDF_F_capped(TU,mu,sigma);
        Flow_vec(k,1)  = Flow;
        Fhigh_vec(k,1) = Fhigh;

        if Fhigh > Flow
            % unnormalized weight: f(t1) * f(t2) * g(k)
            f1 = fct_pdf_f(t1_vec,mu,sigma);
            f2 = fct_pdf_f(t2_vec,mu,sigma);
            z(k,1:MM) = f1 .* f2 .* g_vec(k,1) ;
        else
            % type k has no mass between TL and TU: use the fallback z_min
            z(k,1:MM) = ones(1,MM)*z_min ;
        end

        % y(k): mass of type k between TL and TU (for both spells) times g(k)
        y(k) = (Fhigh-Flow)^2*g_vec(k,1);
    end

    y  = y/sum(y);      % enforces that y is in the simplex

    % log_like(i_EM) is the log-likelihood of all K types at step i_EM,
    % evaluated at the parameters the step started from
    log_like(i_EM) = 0;
    for j = 1:MM
        log_like(i_EM) = data_EM(j,3)*log(sum(z(:,j))) + log_like(i_EM);
        % normalize so that z adds to one across types k for each (t1,t2):
        % z(k,j) is then the fraction of pair j that comes from type k
        z(:,j) = z(:,j)/ sum(z(:,j));
    end
    %dF2             = (Fhigh_vec-Flow_vec).^2;
    dF2             = (Fhigh_vec).^2;
    log_like_partial(i_EM) = log_like(i_EM);
    log_like(i_EM)  = log_like(i_EM) - Npeople* log(dF2'*g_vec);

    %%%% M step: find mu, sigma using ML for each k
    %
    % E_t(k) = E[ t | type k ] and E_1_over_t(k) = E[ 1/t | type k ].
    % Since they condition on coming from type k, the distribution of
    % (t1,t2) in the data is weighted by z(k,t1,t2):
    % weight on pair j = z(k,j)*count(j) / sum_j' z(k,j')*count(j')
    for k = 1:K
        w = z(k,:)' .* data_EM(:,3);

        E_t(k) =  0.5*(t1_vec+t2_vec)' * w / sum(w);

        if use_Elog == 1
            % log(s+1) - log(s) with s = t - tstep is the integral of 1/u
            % over [s, s+1]; used instead of 1/t evaluated at a single point
            E_1_over_t(k) =  0.5 * ( log(t1_veclog-tstep+1)-log(t1_veclog-tstep) ...
                                   + log(t2_veclog-tstep+1)-log(t2_veclog-tstep) )' * ...
                             w / sum(w);
        else
            E_1_over_t(k) =  0.5 * (1./t1_vec + 1./t2_vec )' * ...
                             w / sum(w);
        end
    end

    % the next loop implements the ML estimates conditional on the bounds
    iter_vec = zeros(1,K);   % inner iterations used for each type k
    for k = 1:K
        Et   = E_t(k);
        E1ot = E_1_over_t(k);
        if Et == 0
            % disp accepts a single argument, so format the message first
            disp(sprintf('Et is zero for k %d',k))
        end

        % starting point of the inner iteration
        mu_0    = 1/Et;
        sigma_0 = sqrt(max(mu_0^2*Et + E1ot - 2*mu_0 ,1e-100) );

        iter  = 1;
        dist  = toler_MLE*2;   % guarantees that the loop is entered
        mu    = mu_0;
        sigma = sigma_0;

        while iter < Niter_MLE && dist > toler_MLE

            [F_low Fmu_low Fsigma_low]    = fct_CDF_F_capped_prime(TL,mu,sigma);
            [F_high Fmu_high Fsigma_high] = fct_CDF_F_capped_prime(TU,mu,sigma);

            Dmu_term    = Fmu_high/F_high;
            Dsigma_term = Fsigma_high/F_high;

            if F_high > F_low
                % f.o.c. w.r.t. mu:
                next_mu = (1 - sigma^2*y(k)*Dmu_term)/Et;
                next_mu = max(mu_min, next_mu) ; % enforces the minimum

                % f.o.c. w.r.t. sigma:
                %next_sigma = sqrt(max(mu^2*Et + E1ot - 2*mu + sigma^3*(Fsigma_low-Fsigma_high)/(F_high-F_low),1e-100));
                next_sigma  = sqrt(max(next_mu^2*Et + E1ot - 2*next_mu - sigma^3*y(k)*Dsigma_term, 1e-100));
            else
                % no mass between the bounds: keep the current values, which
                % makes dist = 0 and ends the inner iteration
                next_mu    = mu;
                next_sigma = sigma;
                disp(' Fhigh = Flow ')
            end

            % distance between successive inner iterates
            dist = max(abs(next_mu - mu), abs(next_sigma - sigma));

            % update estimates
            mu    = next_mu;
            sigma = next_sigma;
            iter  = iter+1;
        end

        % keep track of how many iterations were used for this value of k
        iter_vec(k) = iter;

        % updated estimates for type k
        mu_new(k,1)    = max(mu_min, mu) ;
        sigma_new(k,1) = sigma ;
    end

    %%%% updates g estimates with an iterative scheme
    g_tol       = 10^(-5);
    g_iter_max  = 100;
    g_iter_old  = g_vec;
    Fhigh_giter = zeros(K,1);

    % mu_new and sigma_new are fixed during the g iteration, so the products
    % f(t1)*f(t2) and the CDF at TU are evaluated once per type
    % (assumes the helper functions have no side effects)
    pdf_prod = zeros(K,MM);
    for k = 1:K
        f1 = fct_pdf_f(t1_vec,mu_new(k,1),sigma_new(k,1));
        f2 = fct_pdf_f(t2_vec,mu_new(k,1),sigma_new(k,1));
        pdf_prod(k,1:MM) = f1 .* f2;
        Fhigh_giter(k)   = fct_CDF_F_capped_prime(TU,mu_new(k,1),sigma_new(k,1));
    end

    % log-likelihood term with the new (mu,sigma) and the old g
    for k = 1:K
        z(k,1:MM) = pdf_prod(k,:) .* g_vec(k,1);
    end
    log_like_pure_EM(i_EM) = 0;
    for j = 1:MM
        % accumulate onto log_like_pure_EM itself (not onto log_like)
        log_like_pure_EM(i_EM) = data_EM(j,3)*log(sum(z(:,j))) + log_like_pure_EM(i_EM);
    end

    i_g    = 0;
    g_diff = 10;   % larger than g_tol, so at least one iteration is run
    while i_g < g_iter_max && g_diff > g_tol

        % construct z using the new mu, sigma and the current g
        for k = 1:K
            z(k,1:MM) = pdf_prod(k,:) .* g_iter_old(k,1);
        end
        % normalize across types (dimension 1, also correct when K == 1)
        zsum = sum(z,1);
        z    = z./repmat(zsum,K,1);

        sumy = (Fhigh_giter.^2)'*g_iter_old;

        % new value of g
        z_data     = z*data_EM(:,3)/Npeople;
        g_iter_new = z_data*sumy./(Fhigh_giter.^2);

        g_iter_new = g_iter_new/sum(g_iter_new); % makes sure g adds up to one across k
        g_iter_new = max(g_min,g_iter_new)/sum(max(g_min,g_iter_new)); % enforce a strictly positive minimum of g

        g_diff     = max(abs(g_iter_new-g_iter_old));
        g_iter_old = g_iter_new;
        i_g        = i_g+1;
    end

    g_new = g_iter_new;
    %disp(['mean number of iterations = ' num2str(mean(iter_vec),3)]);

    % updates step of EM algorithm
    i_EM = i_EM + 1 ;

    % differences between successive iterations
    diff_mu    = max(abs(mu_vec-mu_new)) ;
    diff_sigma = max(abs(sigma_vec-sigma_new)) ;
    diff_g     = max(abs(g_vec-g_new)) ;
    diff_EM    = max(diff_mu,max(diff_sigma,diff_g));
    diff_vec(i_EM)       = diff_EM;
    diff_mu_vec(i_EM)    = diff_mu;
    diff_sigma_vec(i_EM) = diff_sigma;
    diff_g_vec(i_EM)     = diff_g;

    % updates mu, sigma, g for next iteration
    mu_vec    = mu_new;
    sigma_vec = max(sigma_new,sigma_min) ;
    g_vec     = g_new ;

    % NOTE(review): log_like(i_EM-1) was evaluated at the parameters this
    % iteration started from, while the parameters stored next to it below
    % are the updated ones - confirm that this pairing is intended.

    % store the iterates of the last `tokeep` steps before n_EM
    if i_EM-1 > n_EM - tokeep
        kb = kb+1;
        mu_vec_aux(:,kb)         = mu_vec;
        sigma_vec_aux(:,kb)      = sigma_vec;
        g_vec_aux(:,kb)          = g_vec;
        log_like_aux(kb)         = log_like(i_EM-1);
        log_like_partial_aux(kb) = log_like_partial(i_EM-1);
        log_like_pure_EM_aux(kb) = log_like_pure_EM(i_EM-1);
    end

    % remember the highest achieved log-likelihood
    % (these variables stay undefined if log_like never exceeds -Inf,
    % e.g. when it is NaN)
    if log_like(i_EM-1) > log_like_max
        mu_vec_max    = mu_vec;
        sigma_vec_max = sigma_vec ;
        g_vec_max     = g_vec ;
        log_like_max  = log_like(i_EM-1);
        log_like_partial_max = log_like_partial(i_EM-1);
        log_like_pure_EM_max = log_like_pure_EM(i_EM-1);
    end

end

% Select the stored iterate with the highest log-likelihood. Only the kb
% slots that were actually filled are compared: unfilled slots hold zeros
% and would otherwise beat any negative log-likelihood and return all-zero
% parameters. With kb == 0 (the loop stopped before the storage window) the
% last iterate is kept.
if choose_the_highest == 1 && kb > 0
    [best_log_like, m2] = max(log_like_aux(1:kb));

    mu_vec    = mu_vec_aux(:,m2);
    sigma_vec = sigma_vec_aux(:,m2);
    g_vec     = g_vec_aux(:,m2);

    log_like(i_EM-1)         = best_log_like;
    log_like_partial(i_EM-1) = log_like_partial_aux(m2);
    log_like_pure_EM(i_EM-1) = log_like_pure_EM_aux(m2);
end

out.g_vec = g_vec;
out.mu_vec = mu_vec;
out.sigma_vec = sigma_vec;
out.log_like = log_like;
out.log_like_partial = log_like_partial;
out.log_like_pure_EM = log_like_pure_EM;
out.mu_vec_max    = mu_vec_max;
out.sigma_vec_max = sigma_vec_max ;
out.g_vec_max     = g_vec_max ;
out.log_like_max  = log_like_max;
out.log_like_partial_max = log_like_partial_max;
out.log_like_pure_EM_max = log_like_pure_EM_max;

out.mu_vec_iter = mu_vec_iter;
out.sigma_vec_iter = sigma_vec_iter;
out.g_vec_iter = g_vec_iter;
out.diff_vec = diff_vec;
out.diff_mu_vec = diff_mu_vec;
out.diff_sigma_vec = diff_sigma_vec;
out.diff_g_vec = diff_g_vec;

end
